import numpy as np
import pandas as pd
import statsmodels.api as sm
import math
import sys

## Linear Estimator
def regEstimate(dataset, pbvarb = 'predicted_prob_black', outcome = 'audited'):
    """Estimate the disparity by an OLS regression of ``outcome`` on ``pbvarb``.

    Fits the formula ``"<outcome> ~ <pbvarb>"`` on ``dataset`` with the
    ``HC1`` covariance type, and prints the number of observations used
    to stdout.

    Parameters
    ----------
    dataset : DataFrame holding the ``pbvarb`` and ``outcome`` columns.
    pbvarb : name of the regressor column.
    outcome : name of the dependent-variable column.

    Returns
    -------
    tuple ``(coef, se, black_aud_rate, non_black_aud_rate)``:
    the coefficient on ``pbvarb``, its standard error, the sum of the first
    two fitted parameters, and the first fitted parameter alone.
    """
    model = sm.OLS.from_formula(outcome + ' ~ ' + pbvarb, data=dataset).fit(cov_type='HC1')
    coef = model.params[pbvarb]
    se = model.bse[pbvarb]

    # model.params is a label-indexed Series, so positional access must go
    # through .iloc; a bare integer key relies on a deprecated pandas fallback.
    # NOTE(review): positions 0 and 1 are presumably the intercept and the
    # pbvarb slope under the default formula -- confirm if the formula changes.
    first_param = model.params.iloc[0]
    second_param = model.params.iloc[1]
    black_aud_rate = first_param + second_param
    non_black_aud_rate = first_param

    print("number of obs :" + str(model.nobs))
    return coef, se, black_aud_rate, non_black_aud_rate

## Probabilistic Estimator 
def chenEstimate(dataset, pbvarb = 'predicted_prob_black', outcome = 'audited'):
    """Estimate the disparity with probability-weighted outcome averages.

    The ``outcome`` column is averaged twice: once weighted by ``pbvarb``
    and once weighted by ``1 - pbvarb``.

    Returns
    -------
    tuple ``(est, black_aud_rate, non_black_aud_rate)`` where ``est`` is the
    ``pbvarb``-weighted average minus the ``(1 - pbvarb)``-weighted average.
    """
    weight_black = dataset[pbvarb]
    weight_non_black = 1 - weight_black
    outcome_values = dataset[outcome]

    # Weighted mean of the outcome under each set of weights.
    black_aud_rate = (weight_black * outcome_values).sum() / weight_black.sum()
    non_black_aud_rate = (weight_non_black * outcome_values).sum() / weight_non_black.sum()

    return black_aud_rate - non_black_aud_rate, black_aud_rate, non_black_aud_rate

## Get standard errors
def getSEs(dataset,pbvarb='predicted_prob_black',outcome='audited', seReg = 0.0):
    """Return the supplied regression SE together with its rescaled counterpart.

    ``seReg`` is taken as given from the caller (it is not re-estimated here)
    and is multiplied by ``getSEMultiplier(dataset, pbvarb)`` to obtain the
    second standard error. ``outcome`` is accepted but not used by this
    function.

    Returns
    -------
    tuple ``(seReg, seChen)``.
    """
    seChen = seReg * getSEMultiplier(dataset, pbvarb)
    return seReg, seChen

def getSEMultiplier(dataset,pbvarb='predicted_prob_black'):
    """Return ``sqrt(var(p) / (mean(p) * (1 - mean(p))))`` for column ``pbvarb``.

    ``var`` and ``mean`` are the pandas Series methods with their default
    arguments.
    """
    probabilities = dataset[pbvarb]
    mean_probability = probabilities.mean()
    denominator = mean_probability * (1 - mean_probability)
    return np.sqrt(probabilities.var() / denominator)


### Operational audits and predicted race probabilities
dataBISG = pd.read_csv("/REDACTED/data/final/individualBISG2014_full_final_new_eic.csv")
# Keep only rows with a predicted probability. The .copy() makes the result an
# independent frame so the indicator columns assigned below are not written
# into a slice of the frame returned by read_csv.
dataBISG = dataBISG[~dataBISG.predicted_prob_black.isna()].copy()

# define duplicate child claim audits and non-duplicate child claim audits
# Each entry: (substring searched for in `project`, indicator column to create,
# whether the indicator counts toward `same_child_aud_ind`).
_PROJECT_CODE_COLUMNS = [
    ('0058', 'code_58_ind', True),      # keep
    ('0059', 'code_59_ind', True),      # non-eitc but keep
    ('0070', 'code_70_ind', True),      # none here
    ('0097', 'code_97_ind', True),      # non-eitc but keep
    ('0098', 'code_98_ind', True),      # unsure but only 1
    ('0587', 'code_587_ind', False),    # do not keep
    ('0652', 'code_652_ind', True),     # 'definitely'
    ('1309', 'code_1309_ind', True),    # none here
]

for _code, _column, _counts in _PROJECT_CODE_COLUMNS:
    # na=False: a missing project value is "no match". Without it the mask
    # holds NaN, which np.where treats as truthy and flags as a match.
    # regex=False: the codes are literal substrings, not patterns.
    _matches = dataBISG.project.str.contains(_code, regex=False, na=False)
    dataBISG[_column] = np.where(_matches, 1, 0)

# 1 if any of the counted project codes matched, else 0.
_same_child_columns = [column for _, column, counts in _PROJECT_CODE_COLUMNS if counts]
dataBISG['same_child_aud_ind'] = dataBISG[_same_child_columns].sum(axis=1).clip(upper=1)
dataBISG['non_same_child_aud_ind'] = np.where((dataBISG['aud_no_research_audits'] == 1) & (dataBISG['same_child_aud_ind'] == 0), 1, 0)

_PBVARB = 'predicted_prob_black'
_OUTCOME = 'non_same_child_aud_ind'


def _estimate_disparity(population):
    """Run both estimators on `population`.

    Returns (linear result tuple, probabilistic result tuple, linear SE,
    probabilistic SE).
    """
    lin_result = regEstimate(population, pbvarb = _PBVARB, outcome = _OUTCOME)
    prob_result = chenEstimate(population, pbvarb = _PBVARB, outcome = _OUTCOME)
    lin_se, prob_se = getSEs(population, pbvarb = _PBVARB, outcome = _OUTCOME, seReg = float(lin_result[1]))
    return lin_result, prob_result, lin_se, prob_se


def _print_disparity_report(title, lin_estimate, lin_se, prob_estimate, prob_se):
    """Print one population's disparity estimates and standard errors."""
    print(title + " \n")
    print("Overall Linear Disparity: "+str(lin_estimate))
    print("Linear Standard Error: "+str(lin_se))
    print("Overall Probabilistic Disparity: "+str(prob_estimate))
    print("Probabilistic Standard Error: "+str(prob_se)+"\n")


# write out results
# stdout is redirected (rather than passing file= to print) because
# regEstimate prints its observation count to stdout as well. The with/finally
# pair guarantees the file is closed and stdout restored even if an estimator
# raises.
_previous_stdout = sys.stdout
with open("/REDACTED/disparity_results_non_same_child.txt", "w") as _results_file:
    sys.stdout = _results_file
    try:
        # Full population
        linDisparity, probDisparity, linSE, probSE = _estimate_disparity(dataBISG)

        full_pop_overall_lin = linDisparity[0]
        full_pop_overall_prob = probDisparity[0]

        # EITC
        eitc_pop = dataBISG.loc[dataBISG['isEIC'] == 1]

        linDisparity_eic, probDisparity_eic, linSE_eic, probSE_eic = _estimate_disparity(eitc_pop)

        eitc_overall_lin = linDisparity_eic[0]
        eitc_overall_prob = probDisparity_eic[0]

        # Non-EITC
        non_eitc_pop = dataBISG.loc[dataBISG['isEIC'] == 0]

        linDisparity_non_eic, probDisparity_non_eic, linSE_non_eic, probSE_non_eic = _estimate_disparity(non_eitc_pop)

        non_eitc_overall_lin = linDisparity_non_eic[0]
        non_eitc_overall_prob = probDisparity_non_eic[0]

        # All estimates are computed before any report is printed so that the
        # "number of obs" lines from regEstimate come first in the output file.
        _print_disparity_report("Full population", full_pop_overall_lin, linSE, full_pop_overall_prob, probSE)
        _print_disparity_report("EITC population", eitc_overall_lin, linSE_eic, eitc_overall_prob, probSE_eic)
        _print_disparity_report("Non EITC population", non_eitc_overall_lin, linSE_non_eic, non_eitc_overall_prob, probSE_non_eic)
    finally:
        # Restore whatever stdout was in effect before the redirect.
        sys.stdout = _previous_stdout
